import os

import torch

# Root directories for the project and its dataset.
# Each can be overridden through an environment variable so the config works
# on machines other than the original author's. Without the variable set, the
# original hard-coded location is used unchanged.
project_path = os.environ.get(
    "KEYPOINT_PROJECT_PATH",
    "/Users/ray/Downloads/项目/KeyPointUnet",
)
dataset_path = os.environ.get(
    "KEYPOINT_DATASET_PATH",
    "/Users/ray/Downloads/项目/dataMining/datasets/all",
)

# Maps each key-point label to its integer index (0-13).
# The order of the tuple below defines the indices: the seven "R*"/"TiR" labels
# come first (0-6), followed by the seven "L*"/"TiL" labels (7-13).
index_dict = {
    label: position
    for position, label in enumerate((
        "Rlong1", "Rlong2",
        "Rass1", "Rass2", "Rass3", "Rass4",
        "TiR",
        "Llong1", "Llong2",
        "Lass1", "Lass2", "Lass3", "Lass4",
        "TiL",
    ))
}


# Hyper-parameters and runtime settings for training and evaluation.
# NOTE: the device is CUDA when available, otherwise CPU. On Apple Silicon the
# author previously hard-coded torch.device("mps") here instead.
train_config = dict(
    # --- network training ---
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    lr=1e-4,
    batch_size=1,
    epochs=1000,
    save_epoch=2,
    # --- network evaluation ---
    test_batch_size=1,
)
